"""
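## LaVCa step 3: Generate captions for the optimal images using a multimodal LLM (MLLM); an LLM summarizes these captions in the next step.
## For efficiency, all images are captioned in advance.
## Pass --split START END to caption only directories whose numeric suffix lies in [START, END] (e.g. to split the work across several jobs).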

python3 -m LaVCa.captioning_opt_images \
    --caption_model MiniCPM-Llama3-V-2_5 \
    --dataset_name OpenImages \
    --dataset_path ./data/OpenImages/frames_518x518px \
    --max_samples full \
    --device cuda:0
"""

from transformers import AutoModel, AutoTokenizer
import torch
import argparse
import os
from tqdm import tqdm
from PIL import Image
from utils.utils import load_frames

# Seed torch's global RNG once at import time. Captions are generated with
# sampling, so this fixes the random stream for a fresh run of the script.
_RNG_SEED = 42
torch.manual_seed(_RNG_SEED)

def _load_minicpm(device):
    """Load MiniCPM-Llama3-V-2_5 in float16 and move it to ``device``.

    Returns:
        ``(model, tokenizer)``, with the model switched to eval mode.
    """
    model = AutoModel.from_pretrained(
        'openbmb/MiniCPM-Llama3-V-2_5', trust_remote_code=True, torch_dtype=torch.float16
    )
    model = model.to(device=device)
    tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-Llama3-V-2_5', trust_remote_code=True)
    model.eval()
    return model, tokenizer


def _try_claim(lock_path):
    """Atomically create ``lock_path``; return False if it already exists.

    ``O_CREAT | O_EXCL`` turns "check, then create" into one filesystem
    operation, so two concurrent runs cannot both claim the same frame.
    """
    try:
        fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
    except FileExistsError:
        return False
    os.close(fd)
    return True


def _release_claim(lock_path):
    """Best-effort removal of a lock file created by ``_try_claim``."""
    try:
        os.remove(lock_path)
    except FileNotFoundError:
        pass
    except OSError as err:
        # Cleanup must not mask the captioning result (or its exception).
        print(f"Warning: could not remove {lock_path}: {err}")


def main(args):
    """Caption every stimulus frame that does not already have a caption file.

    For each frame under ``args.dataset_path/<dir_name>/`` a caption is written
    next to the image as ``caption_<basename>_<caption_model>.txt``. While a
    frame is being captioned, a ``temp_<basename>_<caption_model>.tmp`` lock
    file marks it as in progress, so other runs skip it. Frames that already
    have a caption are skipped, which makes the script resumable.

    NOTE(review): if a run is killed hard, its .tmp lock files are left behind
    and those frames will be skipped until the files are deleted manually.

    Args:
        args: argparse.Namespace with ``caption_model``, ``dataset_name``,
            ``dataset_path``, ``max_samples``, ``device`` and ``split``
            (``None``/empty, or ``[START, END]`` inclusive on the numeric
            suffix of directory names).

    Raises:
        ValueError: unsupported ``caption_model`` or ``dataset_name``, or a
            ``split`` that is not exactly two integers.
    """
    dataset_name = args.dataset_name

    # Validate options before the expensive model load. Only one model and one
    # dataset are implemented below.
    if args.caption_model != "MiniCPM-Llama3-V-2_5":
        raise ValueError(f"Unsupported caption_model: {args.caption_model!r}")
    if dataset_name != "OpenImages":
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")
    if args.split and len(args.split) != 2:
        raise ValueError(f"--split expects exactly two integers (START END), got {args.split}")

    # Load the model
    model, tokenizer = _load_minicpm(args.device)
    prompt = "Describe the image briefly."
    msgs = [{'role': 'user', 'content': prompt}]

    # NOTE(review): break_point is parsed (which also validates --max_samples)
    # but is never applied below, so --max_samples currently has no effect.
    # Confirm whether it should cap frames in total or per directory before
    # wiring it in.
    if args.max_samples == "full":
        break_point = 100000000
    else:
        break_point = int(args.max_samples)

    frames_all = load_frames(args.dataset_path, img_type="jpg")
    print(f"Number of directory: {len(frames_all)}")
    # Directory names are expected to look like "<prefix>_<int>"; order by the int.
    frames_all = {k: frames_all[k] for k in sorted(frames_all, key=lambda x: int(x.split('_')[1]))}

    # Optionally restrict to directories whose numeric suffix is in [START, END].
    if args.split:
        start, end = args.split
        frames_all = {k: v for k, v in frames_all.items() if start <= int(k.split('_')[1]) <= end}
        print(f"Selected directories: {frames_all.keys()}")

    for dir_name, frame_paths in tqdm(frames_all.items()):
        stim_dir = os.path.join(args.dataset_path, dir_name)
        for frame_path in frame_paths:
            stim_basename = os.path.basename(frame_path).replace(".jpg", "").replace(".png", "")
            print(f"Now processing {dir_name}/{stim_basename}...")
            caption_file_path = os.path.join(stim_dir, f"caption_{stim_basename}_{args.caption_model}.txt")
            temp_file_path = os.path.join(stim_dir, f"temp_{stim_basename}_{args.caption_model}.tmp")

            if os.path.exists(caption_file_path):
                print(f"Caption for {stim_basename} already exists.")
                continue
            if not _try_claim(temp_file_path):
                print(f"Caption for {stim_basename} is being processed.")
                continue

            try:
                # Another run may have finished this frame between the check
                # above and our claim; do not overwrite its caption.
                if os.path.exists(caption_file_path):
                    print(f"Caption for {stim_basename} already exists.")
                    continue

                with Image.open(frame_path) as img:
                    stim = img.convert("RGB")
                # NOTE: sampling draws from the global torch RNG, which is not
                # reseeded per frame here, so a resumed run does not reproduce
                # the captions of an uninterrupted run.
                response = model.chat(
                    image=stim,
                    msgs=msgs,
                    tokenizer=tokenizer,
                    sampling=True,  # if sampling=False, beam_search will be used by default
                    temperature=0.7,
                    max_new_tokens=50,  # length cap; must be spelled max_new_tokens to take effect
                )
                print(response)

                # Save captions to text file
                with open(caption_file_path, "w", encoding="utf-8") as f:
                    f.write(response)
            finally:
                # Release the claim whether or not captioning succeeded.
                _release_claim(temp_file_path)

def build_arg_parser():
    """Build the command-line parser for this captioning script.

    Returns:
        An ``argparse.ArgumentParser`` whose namespace is what ``main`` expects.
    """
    parser = argparse.ArgumentParser(
        description="Caption LaVCa stimulus images with a multimodal LLM."
    )
    parser.add_argument(
        "--caption_model",
        type=str,
        required=True,
        help="Name of the captioning model to use."
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        required=True,
        help="Dataset identifier (e.g. OpenImages)."
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        required=True,
        help="Root directory whose sub-directories hold the frames; captions are written next to the images."
    )
    # NOTE(review): main() parses this value but does not currently apply it.
    parser.add_argument(
        "--max_samples",
        type=str,
        required=True,
        help='"full" or an integer.'
    )
    parser.add_argument(
        "--device",
        type=str,
        required=True,
        help="Device to use."
    )
    # main() always reads split[0] and split[1], so require exactly two values.
    parser.add_argument(
        "--split",
        type=int,
        nargs=2,
        metavar=("START", "END"),
        help="Only caption directories whose numeric suffix lies in [START, END] (inclusive)."
    )
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())